Fix GridwiseGemmDlMultipleD element op for FloatAcc!=FloatC #3565
Closed
dbsanfte wants to merge 2 commits into ROCm:develop from
Author
I think there's still something wrong with the fused scaling: the cosine similarity to the fp32 reference still does not match the two-kernel approach (CK INT8 GEMM + separate scaling kernel).

With CK + fused scaling: (output not shown)

With the two-kernel approach (CK GEMM kernel + scaling kernel): (output not shown)

Cosine similarity is much better in the latter.
Problem:
- DeviceGemmDl crashes on gfx906 when K >= 1472 with small M (e.g., M=1)
- Root cause: CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK was disabled
- Without the offset trick, invalid buffer loads execute and crash before bounds checking can prevent them

Solution:
1. Enable the OOB offset trick (0x80000000) so invalid coordinates safely return zero instead of accessing unmapped memory
2. Use full coordinate_has_valid_offset() check instead of the _assuming_visible_index_is_valid variant for proper K bounds validation

Verified with INT8 GEMM tests: M=1 decode, K=14336, FFN projections all pass.
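As a rough illustration of the offset trick described in the commit message, here is a host-side simulation. It does not use CK or any GPU intrinsics: `FakeBuffer`, `simulated_buffer_load`, and `load_with_offset_trick` are hypothetical names, and the range check in `simulated_buffer_load` is an assumption standing in for the hardware behavior the commit message relies on (out-of-range buffer offsets return zero).

```cpp
#include <cstdint>
#include <vector>

// Hypothetical host-side stand-in for a GPU buffer resource.
struct FakeBuffer {
    std::vector<int8_t> data;
    uint32_t num_bytes() const { return static_cast<uint32_t>(data.size()); }
};

// ASSUMPTION: models hardware range checking, where a load whose byte offset
// falls outside the buffer returns zero instead of touching memory.
inline int8_t simulated_buffer_load(const FakeBuffer& buf, uint32_t byte_offset) {
    return byte_offset < buf.num_bytes() ? buf.data[byte_offset] : int8_t{0};
}

// The trick: for an invalid coordinate, add 0x80000000 to the offset so the
// load is guaranteed to land outside the buffer and therefore yields zero.
inline int8_t load_with_offset_trick(const FakeBuffer& buf,
                                     uint32_t byte_offset,
                                     bool coordinate_is_valid) {
    const uint32_t shift = coordinate_is_valid ? 0u : 0x80000000u;
    return simulated_buffer_load(buf, byte_offset + shift);
}
```

In this model the invalid load never dereferences memory at all, which is the property the commit message depends on; the alternative of loading first and zeroing the result afterwards would still perform the access.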
Summary
Fixes element operation type mismatch in `GridwiseGemmDlMultipleD_km_kn_mn` when `FloatAcc != FloatC`.

The Bug
When accumulator type differs from output type (e.g., INT8×INT8 → INT32 accumulate → FP32 output), the CDE element op is invoked with references to the wrong storage type.
The element op contract is `(E& e, const C& c, const D& d...)`, where:

- `E` = `FloatC` (the final output type, e.g. `float`)
- `C` = `FloatAcc` (the accumulator type, e.g. `int32_t`)

Current behavior (broken): The kernel builds `dst_data_refs` from `c_thread_buf`, which is `StaticBuffer<FloatAcc>`. This means the element op receives `int32_t&` for both `E&` and `C&`, violating its signature when `FloatAcc != FloatC`.

Why this is wrong:

1. The element op computes `e = f(c, d...)`. If `e` and `c` alias the same storage, the conversion semantics are lost.
2. Element ops that require `float&` for output fail at compile time.
3. The kernel then runs `ThreadwiseTensorSliceTransfer_v1r3<FloatAcc, FloatC, ...>` on `c_thread_buf`, meaning it type-puns `FloatAcc` bits as `FloatC`, which is undefined behavior for non-trivially-convertible types.

Fixed behavior: Introduce a separate `e_thread_buf<FloatC>` for element op output, pass `(E& e)` from this buffer and `(const C& c)` from `c_thread_buf`, then transfer `e_thread_buf` to global memory.

Context / Affected
`DeviceGemmMultipleD_Dl`)

Minimal repro (verified)
This is a compile-time repro: no runtime execution or valid pointers needed.
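The idea behind a compile-time repro of this kind can be sketched without any CK headers. The following is not the PR's repro, only a standalone C++17 illustration under stated assumptions: `ScaleAddSketch` is a hypothetical element op, and the two `static_assert`s model the correct and the broken wiring of the output reference.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical element op (not taken from CK). It is deliberately
// non-templated: the output reference must be float&.
struct ScaleAddSketch {
    float scale;
    void operator()(float& e, const int32_t& c, const float& d) const {
        e = scale * static_cast<float>(c) + d;
    }
};

// Correct wiring: E& refers to FloatC storage (float), C& to FloatAcc (int32_t).
static_assert(std::is_invocable_v<const ScaleAddSketch&,
                                  float&, const int32_t&, const float&>);

// Broken wiring: E& refers to FloatAcc storage. An int32_t lvalue cannot bind
// to a non-const float&, so the call is ill-formed.
static_assert(!std::is_invocable_v<const ScaleAddSketch&,
                                   int32_t&, const int32_t&, const float&>);
```

Because the check happens during overload resolution, no device, buffers, or valid pointers are involved, which matches the "compile-time only" nature of the repro.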
Key point: the element op is intentionally non-templated and requires `float&` for the output ref. If the kernel incorrectly passes an `int32_t&` (`FloatAcc`) as the output ref, compilation fails.

How to reproduce (compile)
On gfx906 + ROCm 7.1.1:
This exploits point 2 from "Why this is wrong" in the bug description to demonstrate the bug at compile time.

Failing error line (upstream develop)
This is the first relevant error (line numbers may vary by commit):
The diagnostic also shows that `dst_data_refs` contains `int&` (`FloatAcc`) where the element op requires `float&` (`FloatC`).
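The shape of the fix described under "Fixed behavior" can be pictured with a small host-side sketch. This is not the kernel code: `DequantAddSketch` and `apply_epilogue` are hypothetical names, `std::array` stands in for the per-thread `StaticBuffer`s, and the returned array plays the role of `e_thread_buf`, which is the buffer that would then be written to global memory.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical element op (not from CK): dequantize the int32 accumulator
// with a scale and add a float D operand.
struct DequantAddSketch {
    float scale;
    void operator()(float& e, const int32_t& c, const float& d) const {
        e = scale * static_cast<float>(c) + d;
    }
};

// Stand-in for the per-thread epilogue. Accumulators stay in `acc` (FloatAcc)
// and are only read; results are written to a separate `out` buffer (FloatC),
// so E& and C& never alias the same storage.
template <std::size_t N, typename ElementOp>
std::array<float, N> apply_epilogue(const std::array<int32_t, N>& acc,
                                    const std::array<float, N>& d,
                                    const ElementOp& op) {
    std::array<float, N> out{};
    for (std::size_t i = 0; i < N; ++i) {
        op(out[i], acc[i], d[i]);
    }
    return out;
}
```

Keeping the accumulator buffer read-only in the epilogue means the conversion from `FloatAcc` to `FloatC` happens exactly once, inside the element op, rather than being implied by how the storage is later reinterpreted.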